/**
 * AVX implementation of sin, cos, sincos, exp and log
 *
 * Based on "sse_mathfun.h", by Julien Pommier
 * http://gruntthepeon.free.fr/ssemath/
 *
 * Copyright (C) 2012 Giovanni Garberoglio
 * Interdisciplinary Laboratory for Computational Science (LISC)
 * Fondazione Bruno Kessler and University of Trento
 * via Sommarive, 18
 * I-38123 Trento (Italy)
 *
 * This software is provided 'as-is', without any express or implied
 * warranty.  In no event will the authors be held liable for any damages
 * arising from the use of this software.
 *
 * Permission is granted to anyone to use this software for any purpose,
 * including commercial applications, and to alter it and redistribute it
 * freely, subject to the following restrictions:
 *
 * 1. The origin of this software must not be misrepresented; you must not
 *    claim that you wrote the original software. If you use this software
 *    in a product, an acknowledgment in the product documentation would be
 *    appreciated but is not required.
 * 2. Altered source versions must be plainly marked as such, and must not be
 *    misrepresented as being the original software.
 * 3. This notice may not be removed or altered from any source distribution.
 *
 * (this is the zlib license)
 *
 * --------------------------------------------------------------------------
 * \file dnn/src/x86/elemwise/avx_util/avx_mathfun.cpp
 *
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *
 * This file has been modified by Megvii ("Megvii Modifications").
 * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights reserved.
 *
 * --------------------------------------------------------------------------
 */
#include "./avx_mathfun.h"
#if !defined(__clang__)
#pragma GCC target("avx2")
#endif

#ifndef __AVX2__
#define __AVX2__
#endif
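
// __AVX2__ is force-defined above so that the AVX2 code paths in this file
// are selected (the GCC pragma enables the avx2 target); it is #undef'd at
// the end of the file.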

#define ALIGN32_BEG
#define ALIGN32_END __attribute__((aligned(32)))

namespace megdnn {
namespace x86 {
namespace detail {

/* __m256 is ugly to write */
typedef __m256 v8sf;   // vector of 8 float (avx)
typedef __m256i v8si;  // vector of 8 int   (avx)
typedef __m128i v4si;  // vector of 4 int   (sse2)

#define _PI32AVX_CONST(Name, Val) \
    static const ALIGN32_BEG int _pi32avx_##Name[4] ALIGN32_END = {Val, Val, Val, Val}

/* constants used by the SSE2 fallback paths below */
_PI32AVX_CONST(1, 1);
_PI32AVX_CONST(inv1, ~1);
_PI32AVX_CONST(2, 2);
_PI32AVX_CONST(4, 4);

/* declare some AVX constants -- why can't I figure out a better way to do that? */
#define _PS256_CONST(Name, Val)                                                        \
    static const ALIGN32_BEG float _ps256_##Name[8] ALIGN32_END = {Val, Val, Val, Val, \
                                                                   Val, Val, Val, Val}
#define _PI32_CONST256(Name, Val)                                    \
    static const ALIGN32_BEG int _pi32_256_##Name[8] ALIGN32_END = { \
            Val, Val, Val, Val, Val, Val, Val, Val}
#define _PS256_CONST_TYPE(Name, Type, Val)                                            \
    static const ALIGN32_BEG Type _ps256_##Name[8] ALIGN32_END = {Val, Val, Val, Val, \
                                                                  Val, Val, Val, Val}

_PS256_CONST(1, 1.0f);
_PS256_CONST(0p5, 0.5f);
/* the smallest non-denormalized float number */
_PS256_CONST_TYPE(min_norm_pos, int, 0x00800000);
//_PS256_CONST_TYPE(mant_mask, int, 0x7f800000);
_PS256_CONST_TYPE(inv_mant_mask, int, ~0x7f800000);

_PS256_CONST_TYPE(sign_mask, int, (int)0x80000000);
_PS256_CONST_TYPE(inv_sign_mask, int, ~0x80000000);

_PI32_CONST256(0, 0);
_PI32_CONST256(1, 1);
_PI32_CONST256(inv1, ~1);
_PI32_CONST256(2, 2);
_PI32_CONST256(4, 4);
_PI32_CONST256(0x7f, 0x7f);

_PS256_CONST(cephes_SQRTHF, 0.707106781186547524);
_PS256_CONST(cephes_log_p0, 7.0376836292E-2);
_PS256_CONST(cephes_log_p1, -1.1514610310E-1);
_PS256_CONST(cephes_log_p2, 1.1676998740E-1);
_PS256_CONST(cephes_log_p3, -1.2420140846E-1);
_PS256_CONST(cephes_log_p4, +1.4249322787E-1);
_PS256_CONST(cephes_log_p5, -1.6668057665E-1);
_PS256_CONST(cephes_log_p6, +2.0000714765E-1);
_PS256_CONST(cephes_log_p7, -2.4999993993E-1);
_PS256_CONST(cephes_log_p8, +3.3333331174E-1);
_PS256_CONST(cephes_log_q1, -2.12194440e-4);
_PS256_CONST(cephes_log_q2, 0.693359375);

#ifndef __AVX2__

typedef union imm_xmm_union {
    v8si imm;
    v4si xmm[2];
} imm_xmm_union;

#define COPY_IMM_TO_XMM(imm_, xmm0_, xmm1_)           \
    {                                                 \
        imm_xmm_union u __attribute__((aligned(32))); \
        u.imm = imm_;                                 \
        xmm0_ = u.xmm[0];                             \
        xmm1_ = u.xmm[1];                             \
    }

#define COPY_XMM_TO_IMM(xmm0_, xmm1_, imm_)           \
    {                                                 \
        imm_xmm_union u __attribute__((aligned(32))); \
        u.xmm[0] = xmm0_;                             \
        u.xmm[1] = xmm1_;                             \
        imm_ = u.imm;                                 \
    }

#define AVX2_BITOP_USING_SSE2(fn)                                                   \
    MEGDNN_ATTRIBUTE_TARGET("avx2") static inline v8si _mm256_##fn(v8si x, int a) { \
        /* use SSE2 instructions to perform the AVX2 bitop */                       \
        v4si x1, x2;                                                                \
        v8si ret;                                                                   \
        COPY_IMM_TO_XMM(x, x1, x2);                                                 \
        x1 = _mm_##fn(x1, a);                                                       \
        x2 = _mm_##fn(x2, a);                                                       \
        COPY_XMM_TO_IMM(x1, x2, ret);                                               \
        return (ret);                                                               \
    }
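
/* For example, AVX2_BITOP_USING_SSE2(slli_epi32) defines a drop-in
   _mm256_slli_epi32 that splits its argument into two 128-bit halves via the
   union above, shifts each half with the SSE2 _mm_slli_epi32, and reassembles
   the 256-bit result. */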

#warning "Using SSE2 to perform AVX2 bitshift ops"
AVX2_BITOP_USING_SSE2(slli_epi32)
AVX2_BITOP_USING_SSE2(srli_epi32)

#define AVX2_INTOP_USING_SSE2(fn)                                                    \
    MEGDNN_ATTRIBUTE_TARGET("avx2") static inline v8si _mm256_##fn(v8si x, v8si y) { \
        /* use SSE2 instructions to perform the AVX2 integer operation */            \
        v4si x1, x2;                                                                 \
        v4si y1, y2;                                                                 \
        v8si ret;                                                                    \
        COPY_IMM_TO_XMM(x, x1, x2);                                                  \
        COPY_IMM_TO_XMM(y, y1, y2);                                                  \
        x1 = _mm_##fn(x1, y1);                                                       \
        x2 = _mm_##fn(x2, y2);                                                       \
        COPY_XMM_TO_IMM(x1, x2, ret);                                                \
        return (ret);                                                                \
    }

#warning "Using SSE2 to perform AVX2 integer ops"
AVX2_INTOP_USING_SSE2(and_si128)
AVX2_INTOP_USING_SSE2(andnot_si128)
AVX2_INTOP_USING_SSE2(cmpeq_epi32)
AVX2_INTOP_USING_SSE2(sub_epi32)
AVX2_INTOP_USING_SSE2(add_epi32)

#endif /* __AVX2__ */

/* natural logarithm computed for 8 floats simultaneously;
   returns NaN for x <= 0
*/
v8sf log256_ps(v8sf x) {
    v8si imm0;
    v8sf one = *(v8sf*)_ps256_1;

    // v8sf invalid_mask = _mm256_cmple_ps(x, _mm256_setzero_ps());
    v8sf invalid_mask = _mm256_cmp_ps(x, _mm256_setzero_ps(), _CMP_LE_OS);

    x = _mm256_max_ps(x, *(v8sf*)_ps256_min_norm_pos); /* cut off denormalized stuff */

    // extract the biased exponent (bits 23..30); this shift is an AVX2 instruction
    imm0 = _mm256_srli_epi32(_mm256_castps_si256(x), 23);

    /* keep only the fractional part */
    x = _mm256_and_ps(x, *(v8sf*)_ps256_inv_mant_mask);
    x = _mm256_or_ps(x, *(v8sf*)_ps256_0p5);
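    /* x now holds the mantissa scaled into [0.5, 1), so the input equals
       x * 2^(e+1) and log(input) = log(x) + (e+1)*log(2) -- hence the
       e += 1 a few lines below */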

    // subtract the exponent bias (0x7f); again an AVX2 integer instruction
    imm0 = _mm256_sub_epi32(imm0, *(v8si*)_pi32_256_0x7f);
    v8sf e = _mm256_cvtepi32_ps(imm0);

    e = _mm256_add_ps(e, one);

    /* part2:
       if( x < SQRTHF ) {
         e -= 1;
         x = x + x - 1.0;
       } else { x = x - 1.0; }
    */
    // v8sf mask = _mm256_cmplt_ps(x, *(v8sf*)_ps256_cephes_SQRTHF);
    v8sf mask = _mm256_cmp_ps(x, *(v8sf*)_ps256_cephes_SQRTHF, _CMP_LT_OS);
    v8sf tmp = _mm256_and_ps(x, mask);
    x = _mm256_sub_ps(x, one);
    e = _mm256_sub_ps(e, _mm256_and_ps(one, mask));
    x = _mm256_add_ps(x, tmp);

    v8sf z = _mm256_mul_ps(x, x);

    v8sf y = *(v8sf*)_ps256_cephes_log_p0;
    y = _mm256_mul_ps(y, x);
    y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p1);
    y = _mm256_mul_ps(y, x);
    y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p2);
    y = _mm256_mul_ps(y, x);
    y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p3);
    y = _mm256_mul_ps(y, x);
    y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p4);
    y = _mm256_mul_ps(y, x);
    y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p5);
    y = _mm256_mul_ps(y, x);
    y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p6);
    y = _mm256_mul_ps(y, x);
    y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p7);
    y = _mm256_mul_ps(y, x);
    y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p8);
    y = _mm256_mul_ps(y, x);

    y = _mm256_mul_ps(y, z);

    tmp = _mm256_mul_ps(e, *(v8sf*)_ps256_cephes_log_q1);
    y = _mm256_add_ps(y, tmp);

    tmp = _mm256_mul_ps(z, *(v8sf*)_ps256_0p5);
    y = _mm256_sub_ps(y, tmp);

    tmp = _mm256_mul_ps(e, *(v8sf*)_ps256_cephes_log_q2);
    x = _mm256_add_ps(x, y);
    x = _mm256_add_ps(x, tmp);
    x = _mm256_or_ps(x, invalid_mask);  // negative arg will be NAN
    return x;
}
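
/* Minimal usage sketch for log256_ps (illustrative only; `src` and `dst` are
   hypothetical float buffers holding at least 8 elements):

     v8sf v = _mm256_loadu_ps(src);   // unaligned load of 8 floats
     v = log256_ps(v);                // 8 natural logarithms at once
     _mm256_storeu_ps(dst, v);        // store the 8 results
*/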

_PS256_CONST(exp_hi, 88.3762626647949f);
_PS256_CONST(exp_lo, -88.3762626647949f);

_PS256_CONST(cephes_LOG2EF, 1.44269504088896341);
_PS256_CONST(cephes_exp_C1, 0.693359375);
_PS256_CONST(cephes_exp_C2, -2.12194440e-4);

_PS256_CONST(cephes_exp_p0, 1.9875691500E-4);
_PS256_CONST(cephes_exp_p1, 1.3981999507E-3);
_PS256_CONST(cephes_exp_p2, 8.3334519073E-3);
_PS256_CONST(cephes_exp_p3, 4.1665795894E-2);
_PS256_CONST(cephes_exp_p4, 1.6666665459E-1);
_PS256_CONST(cephes_exp_p5, 5.0000001201E-1);

v8sf exp256_ps(v8sf x) {
    v8sf tmp = _mm256_setzero_ps(), fx;
    v8si imm0;
    v8sf one = *(v8sf*)_ps256_1;

    x = _mm256_min_ps(x, *(v8sf*)_ps256_exp_hi);
    x = _mm256_max_ps(x, *(v8sf*)_ps256_exp_lo);

    /* express exp(x) as exp(g + n*log(2)) */
    fx = _mm256_mul_ps(x, *(v8sf*)_ps256_cephes_LOG2EF);
    fx = _mm256_add_ps(fx, *(v8sf*)_ps256_0p5);
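    /* fx ~ x/log(2) + 0.5, so flooring it below yields n = round(x/log(2));
       exp(x) is then reconstructed as 2^n * exp(g) with g = x - n*log(2) */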

    /* with AVX we can use _mm256_floor_ps directly, instead of the SSE
       truncate-and-convert trick shown below */
    // imm0 = _mm256_cvttps_epi32(fx);
    // tmp  = _mm256_cvtepi32_ps(imm0);

    tmp = _mm256_floor_ps(fx);

    /* if greater, subtract 1 */
    // v8sf mask = _mm256_cmpgt_ps(tmp, fx);
    v8sf mask = _mm256_cmp_ps(tmp, fx, _CMP_GT_OS);
    mask = _mm256_and_ps(mask, one);
    fx = _mm256_sub_ps(tmp, mask);

    tmp = _mm256_mul_ps(fx, *(v8sf*)_ps256_cephes_exp_C1);
    v8sf z = _mm256_mul_ps(fx, *(v8sf*)_ps256_cephes_exp_C2);
    x = _mm256_sub_ps(x, tmp);
    x = _mm256_sub_ps(x, z);

    z = _mm256_mul_ps(x, x);

    v8sf y = *(v8sf*)_ps256_cephes_exp_p0;
    y = _mm256_mul_ps(y, x);
    y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_exp_p1);
    y = _mm256_mul_ps(y, x);
    y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_exp_p2);
    y = _mm256_mul_ps(y, x);
    y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_exp_p3);
    y = _mm256_mul_ps(y, x);
    y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_exp_p4);
    y = _mm256_mul_ps(y, x);
    y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_exp_p5);
    y = _mm256_mul_ps(y, z);
    y = _mm256_add_ps(y, x);
    y = _mm256_add_ps(y, one);

    /* build 2^n */
    imm0 = _mm256_cvttps_epi32(fx);
    // another two AVX2 instructions
    imm0 = _mm256_add_epi32(imm0, *(v8si*)_pi32_256_0x7f);
    imm0 = _mm256_slli_epi32(imm0, 23);
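    // imm0 now holds (n + 127) << 23 in each lane, i.e. the IEEE-754 bit
    // pattern of 2^n (zero mantissa, biased exponent n + 127)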
    v8sf pow2n = _mm256_castsi256_ps(imm0);
    y = _mm256_mul_ps(y, pow2n);
    return y;
}

_PS256_CONST(minus_cephes_DP1, -0.78515625);
_PS256_CONST(minus_cephes_DP2, -2.4187564849853515625e-4);
_PS256_CONST(minus_cephes_DP3, -3.77489497744594108e-8);
_PS256_CONST(sincof_p0, -1.9515295891E-4);
_PS256_CONST(sincof_p1, 8.3321608736E-3);
_PS256_CONST(sincof_p2, -1.6666654611E-1);
_PS256_CONST(coscof_p0, 2.443315711809948E-005);
_PS256_CONST(coscof_p1, -1.388731625493765E-003);
_PS256_CONST(coscof_p2, 4.166664568298827E-002);
_PS256_CONST(cephes_FOPI, 1.27323954473516);  // 4 / M_PI

/* evaluation of 8 sines at once using AVX intrinsics

   The code is an exact rewriting of the cephes sinf function.
   Precision is excellent as long as x < 8192 (I did not bother to
   take into account the special handling cephes applies to greater
   values -- it does not return garbage for arguments over 8192, but
   the extra precision is missing).

   Note that it is such that sinf((float)M_PI) = 8.74e-8, which is the
   surprising but correct result.
*/
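
/* Scalar sketch of the reduction performed below, following the cephes sinf
   source:

     j = (int)(x * 4/Pi);                 // octant of the argument
     j = (j + 1) & ~1;  y = (float)j;     // round up to an even octant
     x = ((x - y*DP1) - y*DP2) - y*DP3;   // x - y*Pi/4 in extended precision

   The octant then selects between the sine and cosine polynomials and fixes
   the sign of the result. */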
v8sf sin256_ps(v8sf x) {  // any x
    v8sf xmm1, xmm2 = _mm256_setzero_ps(), xmm3, sign_bit, y;
    v8si imm0, imm2;

#ifndef __AVX2__
    v4si imm0_1, imm0_2;
    v4si imm2_1, imm2_2;
#endif

    sign_bit = x;
    /* take the absolute value */
    x = _mm256_and_ps(x, *(v8sf*)_ps256_inv_sign_mask);
    /* extract the sign bit (upper one) */
    sign_bit = _mm256_and_ps(sign_bit, *(v8sf*)_ps256_sign_mask);

    /* scale by 4/Pi */
    y = _mm256_mul_ps(x, *(v8sf*)_ps256_cephes_FOPI);

    /*
      Here we start a series of integer operations, which are in the
      realm of AVX2.
      If we don't have AVX2, perform them with SSE2 intrinsics instead.
    */

#ifdef __AVX2__
    /* store the integer part of y in imm2 */
    imm2 = _mm256_cvttps_epi32(y);
    /* j=(j+1) & (~1) (see the cephes sources) */
    // another two AVX2 instructions
    imm2 = _mm256_add_epi32(imm2, *(v8si*)_pi32_256_1);
    imm2 = _mm256_and_si256(imm2, *(v8si*)_pi32_256_inv1);
    y = _mm256_cvtepi32_ps(imm2);

    /* get the swap sign flag */
    imm0 = _mm256_and_si256(imm2, *(v8si*)_pi32_256_4);
    imm0 = _mm256_slli_epi32(imm0, 29);
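    // bit 2 of the octant index now sits in the float sign bit (bit 31)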
    /* get the polynomial selection mask:
       there is one polynomial for 0 <= x <= Pi/4
       and another one for Pi/4 < x <= Pi/2

       Both branches will be computed.
    */
    imm2 = _mm256_and_si256(imm2, *(v8si*)_pi32_256_2);
    imm2 = _mm256_cmpeq_epi32(imm2, *(v8si*)_pi32_256_0);
#else
    /* we use SSE2 routines to perform the integer ops */
    COPY_IMM_TO_XMM(_mm256_cvttps_epi32(y), imm2_1, imm2_2);

    imm2_1 = _mm_add_epi32(imm2_1, *(v4si*)_pi32avx_1);
    imm2_2 = _mm_add_epi32(imm2_2, *(v4si*)_pi32avx_1);

    imm2_1 = _mm_and_si128(imm2_1, *(v4si*)_pi32avx_inv1);
    imm2_2 = _mm_and_si128(imm2_2, *(v4si*)_pi32avx_inv1);

    COPY_XMM_TO_IMM(imm2_1, imm2_2, imm2);
    y = _mm256_cvtepi32_ps(imm2);

    imm0_1 = _mm_and_si128(imm2_1, *(v4si*)_pi32avx_4);
    imm0_2 = _mm_and_si128(imm2_2, *(v4si*)_pi32avx_4);

    imm0_1 = _mm_slli_epi32(imm0_1, 29);
    imm0_2 = _mm_slli_epi32(imm0_2, 29);

    COPY_XMM_TO_IMM(imm0_1, imm0_2, imm0);

    imm2_1 = _mm_and_si128(imm2_1, *(v4si*)_pi32avx_2);
    imm2_2 = _mm_and_si128(imm2_2, *(v4si*)_pi32avx_2);

    imm2_1 = _mm_cmpeq_epi32(imm2_1, _mm_setzero_si128());
    imm2_2 = _mm_cmpeq_epi32(imm2_2, _mm_setzero_si128());

    COPY_XMM_TO_IMM(imm2_1, imm2_2, imm2);
#endif

    v8sf swap_sign_bit = _mm256_castsi256_ps(imm0);
    v8sf poly_mask = _mm256_castsi256_ps(imm2);
    sign_bit = _mm256_xor_ps(sign_bit, swap_sign_bit);

    /* The magic pass: "Extended precision modular arithmetic"
       x = ((x - y * DP1) - y * DP2) - y * DP3; */
    xmm1 = *(v8sf*)_ps256_minus_cephes_DP1;
    xmm2 = *(v8sf*)_ps256_minus_cephes_DP2;
    xmm3 = *(v8sf*)_ps256_minus_cephes_DP3;
    xmm1 = _mm256_mul_ps(y, xmm1);
    xmm2 = _mm256_mul_ps(y, xmm2);
    xmm3 = _mm256_mul_ps(y, xmm3);
    x = _mm256_add_ps(x, xmm1);
    x = _mm256_add_ps(x, xmm2);
    x = _mm256_add_ps(x, xmm3);

    /* Evaluate the first polynomial  (0 <= x <= Pi/4) */
    y = *(v8sf*)_ps256_coscof_p0;
    v8sf z = _mm256_mul_ps(x, x);

    y = _mm256_mul_ps(y, z);
    y = _mm256_add_ps(y, *(v8sf*)_ps256_coscof_p1);
    y = _mm256_mul_ps(y, z);
    y = _mm256_add_ps(y, *(v8sf*)_ps256_coscof_p2);
    y = _mm256_mul_ps(y, z);
    y = _mm256_mul_ps(y, z);
    v8sf tmp = _mm256_mul_ps(z, *(v8sf*)_ps256_0p5);
    y = _mm256_sub_ps(y, tmp);
    y = _mm256_add_ps(y, *(v8sf*)_ps256_1);

    /* Evaluate the second polynomial  (Pi/4 <= x <= Pi/2) */

    v8sf y2 = *(v8sf*)_ps256_sincof_p0;
    y2 = _mm256_mul_ps(y2, z);
    y2 = _mm256_add_ps(y2, *(v8sf*)_ps256_sincof_p1);
    y2 = _mm256_mul_ps(y2, z);
    y2 = _mm256_add_ps(y2, *(v8sf*)_ps256_sincof_p2);
    y2 = _mm256_mul_ps(y2, z);
    y2 = _mm256_mul_ps(y2, x);
    y2 = _mm256_add_ps(y2, x);

    /* select the correct result from the two polynomials */
    xmm3 = poly_mask;
    y2 = _mm256_and_ps(xmm3, y2);
    y = _mm256_andnot_ps(xmm3, y);
    y = _mm256_add_ps(y, y2);
    /* update the sign */
    y = _mm256_xor_ps(y, sign_bit);

    return y;
}

/* almost the same as sin256_ps */
v8sf cos256_ps(v8sf x) {  // any x
    v8sf xmm1, xmm2 = _mm256_setzero_ps(), xmm3, y;
    v8si imm0, imm2;

#ifndef __AVX2__
    v4si imm0_1, imm0_2;
    v4si imm2_1, imm2_2;
#endif

    /* take the absolute value */
    x = _mm256_and_ps(x, *(v8sf*)_ps256_inv_sign_mask);

    /* scale by 4/Pi */
    y = _mm256_mul_ps(x, *(v8sf*)_ps256_cephes_FOPI);

#ifdef __AVX2__
    /* store the integer part of y in imm2 */
    imm2 = _mm256_cvttps_epi32(y);
    /* j=(j+1) & (~1) (see the cephes sources) */
    imm2 = _mm256_add_epi32(imm2, *(v8si*)_pi32_256_1);
    imm2 = _mm256_and_si256(imm2, *(v8si*)_pi32_256_inv1);
    y = _mm256_cvtepi32_ps(imm2);
    imm2 = _mm256_sub_epi32(imm2, *(v8si*)_pi32_256_2);

    /* get the swap sign flag */
    imm0 = _mm256_andnot_si256(imm2, *(v8si*)_pi32_256_4);
    imm0 = _mm256_slli_epi32(imm0, 29);
    /* get the polynomial selection mask */
    imm2 = _mm256_and_si256(imm2, *(v8si*)_pi32_256_2);
    imm2 = _mm256_cmpeq_epi32(imm2, *(v8si*)_pi32_256_0);
#else

    /* we use SSE2 routines to perform the integer ops */
    COPY_IMM_TO_XMM(_mm256_cvttps_epi32(y), imm2_1, imm2_2);

    imm2_1 = _mm_add_epi32(imm2_1, *(v4si*)_pi32avx_1);
    imm2_2 = _mm_add_epi32(imm2_2, *(v4si*)_pi32avx_1);

    imm2_1 = _mm_and_si128(imm2_1, *(v4si*)_pi32avx_inv1);
    imm2_2 = _mm_and_si128(imm2_2, *(v4si*)_pi32avx_inv1);

    COPY_XMM_TO_IMM(imm2_1, imm2_2, imm2);
    y = _mm256_cvtepi32_ps(imm2);

    imm2_1 = _mm_sub_epi32(imm2_1, *(v4si*)_pi32avx_2);
    imm2_2 = _mm_sub_epi32(imm2_2, *(v4si*)_pi32avx_2);

    imm0_1 = _mm_andnot_si128(imm2_1, *(v4si*)_pi32avx_4);
    imm0_2 = _mm_andnot_si128(imm2_2, *(v4si*)_pi32avx_4);

    imm0_1 = _mm_slli_epi32(imm0_1, 29);
    imm0_2 = _mm_slli_epi32(imm0_2, 29);

    COPY_XMM_TO_IMM(imm0_1, imm0_2, imm0);

    imm2_1 = _mm_and_si128(imm2_1, *(v4si*)_pi32avx_2);
    imm2_2 = _mm_and_si128(imm2_2, *(v4si*)_pi32avx_2);

    imm2_1 = _mm_cmpeq_epi32(imm2_1, _mm_setzero_si128());
    imm2_2 = _mm_cmpeq_epi32(imm2_2, _mm_setzero_si128());

    COPY_XMM_TO_IMM(imm2_1, imm2_2, imm2);
#endif

    v8sf sign_bit = _mm256_castsi256_ps(imm0);
    v8sf poly_mask = _mm256_castsi256_ps(imm2);

    /* The magic pass: "Extended precision modular arithmetic"
       x = ((x - y * DP1) - y * DP2) - y * DP3; */
    xmm1 = *(v8sf*)_ps256_minus_cephes_DP1;
    xmm2 = *(v8sf*)_ps256_minus_cephes_DP2;
    xmm3 = *(v8sf*)_ps256_minus_cephes_DP3;
    xmm1 = _mm256_mul_ps(y, xmm1);
    xmm2 = _mm256_mul_ps(y, xmm2);
    xmm3 = _mm256_mul_ps(y, xmm3);
    x = _mm256_add_ps(x, xmm1);
    x = _mm256_add_ps(x, xmm2);
    x = _mm256_add_ps(x, xmm3);

    /* Evaluate the first polynomial  (0 <= x <= Pi/4) */
    y = *(v8sf*)_ps256_coscof_p0;
    v8sf z = _mm256_mul_ps(x, x);

    y = _mm256_mul_ps(y, z);
    y = _mm256_add_ps(y, *(v8sf*)_ps256_coscof_p1);
    y = _mm256_mul_ps(y, z);
    y = _mm256_add_ps(y, *(v8sf*)_ps256_coscof_p2);
    y = _mm256_mul_ps(y, z);
    y = _mm256_mul_ps(y, z);
    v8sf tmp = _mm256_mul_ps(z, *(v8sf*)_ps256_0p5);
    y = _mm256_sub_ps(y, tmp);
    y = _mm256_add_ps(y, *(v8sf*)_ps256_1);

    /* Evaluate the second polynomial  (Pi/4 <= x <= Pi/2) */

    v8sf y2 = *(v8sf*)_ps256_sincof_p0;
    y2 = _mm256_mul_ps(y2, z);
    y2 = _mm256_add_ps(y2, *(v8sf*)_ps256_sincof_p1);
    y2 = _mm256_mul_ps(y2, z);
    y2 = _mm256_add_ps(y2, *(v8sf*)_ps256_sincof_p2);
    y2 = _mm256_mul_ps(y2, z);
    y2 = _mm256_mul_ps(y2, x);
    y2 = _mm256_add_ps(y2, x);

    /* select the correct result from the two polynomials */
    xmm3 = poly_mask;
    y2 = _mm256_and_ps(xmm3, y2);
    y = _mm256_andnot_ps(xmm3, y);
    y = _mm256_add_ps(y, y2);
    /* update the sign */
    y = _mm256_xor_ps(y, sign_bit);

    return y;
}

/* since sin256_ps and cos256_ps are almost identical, sincos256_ps could replace
   both of them. It is almost as fast, and gives you a free cosine with your sine */
void sincos256_ps(v8sf x, v8sf* s, v8sf* c) {
    v8sf xmm1, xmm2, xmm3 = _mm256_setzero_ps(), sign_bit_sin, y;
    v8si imm0, imm2, imm4;

#ifndef __AVX2__
    v4si imm0_1, imm0_2;
    v4si imm2_1, imm2_2;
    v4si imm4_1, imm4_2;
#endif

    sign_bit_sin = x;
    /* take the absolute value */
    x = _mm256_and_ps(x, *(v8sf*)_ps256_inv_sign_mask);
    /* extract the sign bit (upper one) */
    sign_bit_sin = _mm256_and_ps(sign_bit_sin, *(v8sf*)_ps256_sign_mask);

    /* scale by 4/Pi */
    y = _mm256_mul_ps(x, *(v8sf*)_ps256_cephes_FOPI);

#ifdef __AVX2__
    /* store the integer part of y in imm2 */
    imm2 = _mm256_cvttps_epi32(y);

    /* j=(j+1) & (~1) (see the cephes sources) */
    imm2 = _mm256_add_epi32(imm2, *(v8si*)_pi32_256_1);
    imm2 = _mm256_and_si256(imm2, *(v8si*)_pi32_256_inv1);

    y = _mm256_cvtepi32_ps(imm2);
    imm4 = imm2;

    /* get the swap sign flag for the sine */
    imm0 = _mm256_and_si256(imm2, *(v8si*)_pi32_256_4);
    imm0 = _mm256_slli_epi32(imm0, 29);
    // v8sf swap_sign_bit_sin = _mm256_castsi256_ps(imm0);

    /* get the polynomial selection mask for the sine */
    imm2 = _mm256_and_si256(imm2, *(v8si*)_pi32_256_2);
    imm2 = _mm256_cmpeq_epi32(imm2, *(v8si*)_pi32_256_0);
    // v8sf poly_mask = _mm256_castsi256_ps(imm2);
#else
    /* we use SSE2 routines to perform the integer ops */
    COPY_IMM_TO_XMM(_mm256_cvttps_epi32(y), imm2_1, imm2_2);

    imm2_1 = _mm_add_epi32(imm2_1, *(v4si*)_pi32avx_1);
    imm2_2 = _mm_add_epi32(imm2_2, *(v4si*)_pi32avx_1);

    imm2_1 = _mm_and_si128(imm2_1, *(v4si*)_pi32avx_inv1);
    imm2_2 = _mm_and_si128(imm2_2, *(v4si*)_pi32avx_inv1);

    COPY_XMM_TO_IMM(imm2_1, imm2_2, imm2);
    y = _mm256_cvtepi32_ps(imm2);

    imm4_1 = imm2_1;
    imm4_2 = imm2_2;

    imm0_1 = _mm_and_si128(imm2_1, *(v4si*)_pi32avx_4);
    imm0_2 = _mm_and_si128(imm2_2, *(v4si*)_pi32avx_4);

    imm0_1 = _mm_slli_epi32(imm0_1, 29);
    imm0_2 = _mm_slli_epi32(imm0_2, 29);

    COPY_XMM_TO_IMM(imm0_1, imm0_2, imm0);

    imm2_1 = _mm_and_si128(imm2_1, *(v4si*)_pi32avx_2);
    imm2_2 = _mm_and_si128(imm2_2, *(v4si*)_pi32avx_2);

    imm2_1 = _mm_cmpeq_epi32(imm2_1, _mm_setzero_si128());
    imm2_2 = _mm_cmpeq_epi32(imm2_2, _mm_setzero_si128());

    COPY_XMM_TO_IMM(imm2_1, imm2_2, imm2);
#endif
    v8sf swap_sign_bit_sin = _mm256_castsi256_ps(imm0);
    v8sf poly_mask = _mm256_castsi256_ps(imm2);

    /* The magic pass: "Extended precision modular arithmetic"
       x = ((x - y * DP1) - y * DP2) - y * DP3; */
    xmm1 = *(v8sf*)_ps256_minus_cephes_DP1;
    xmm2 = *(v8sf*)_ps256_minus_cephes_DP2;
    xmm3 = *(v8sf*)_ps256_minus_cephes_DP3;
    xmm1 = _mm256_mul_ps(y, xmm1);
    xmm2 = _mm256_mul_ps(y, xmm2);
    xmm3 = _mm256_mul_ps(y, xmm3);
    x = _mm256_add_ps(x, xmm1);
    x = _mm256_add_ps(x, xmm2);
    x = _mm256_add_ps(x, xmm3);

#ifdef __AVX2__
    imm4 = _mm256_sub_epi32(imm4, *(v8si*)_pi32_256_2);
    imm4 = _mm256_andnot_si256(imm4, *(v8si*)_pi32_256_4);
    imm4 = _mm256_slli_epi32(imm4, 29);
#else
    imm4_1 = _mm_sub_epi32(imm4_1, *(v4si*)_pi32avx_2);
    imm4_2 = _mm_sub_epi32(imm4_2, *(v4si*)_pi32avx_2);

    imm4_1 = _mm_andnot_si128(imm4_1, *(v4si*)_pi32avx_4);
    imm4_2 = _mm_andnot_si128(imm4_2, *(v4si*)_pi32avx_4);

    imm4_1 = _mm_slli_epi32(imm4_1, 29);
    imm4_2 = _mm_slli_epi32(imm4_2, 29);

    COPY_XMM_TO_IMM(imm4_1, imm4_2, imm4);
#endif

    v8sf sign_bit_cos = _mm256_castsi256_ps(imm4);

    sign_bit_sin = _mm256_xor_ps(sign_bit_sin, swap_sign_bit_sin);

    /* Evaluate the first polynomial  (0 <= x <= Pi/4) */
    v8sf z = _mm256_mul_ps(x, x);
    y = *(v8sf*)_ps256_coscof_p0;

    y = _mm256_mul_ps(y, z);
    y = _mm256_add_ps(y, *(v8sf*)_ps256_coscof_p1);
    y = _mm256_mul_ps(y, z);
    y = _mm256_add_ps(y, *(v8sf*)_ps256_coscof_p2);
    y = _mm256_mul_ps(y, z);
    y = _mm256_mul_ps(y, z);
    v8sf tmp = _mm256_mul_ps(z, *(v8sf*)_ps256_0p5);
    y = _mm256_sub_ps(y, tmp);
    y = _mm256_add_ps(y, *(v8sf*)_ps256_1);

    /* Evaluate the second polynomial  (Pi/4 <= x <= Pi/2) */

    v8sf y2 = *(v8sf*)_ps256_sincof_p0;
    y2 = _mm256_mul_ps(y2, z);
    y2 = _mm256_add_ps(y2, *(v8sf*)_ps256_sincof_p1);
    y2 = _mm256_mul_ps(y2, z);
    y2 = _mm256_add_ps(y2, *(v8sf*)_ps256_sincof_p2);
    y2 = _mm256_mul_ps(y2, z);
    y2 = _mm256_mul_ps(y2, x);
    y2 = _mm256_add_ps(y2, x);

    /* select the correct result from the two polynomials */
    xmm3 = poly_mask;
    v8sf ysin2 = _mm256_and_ps(xmm3, y2);
    v8sf ysin1 = _mm256_andnot_ps(xmm3, y);
    y2 = _mm256_sub_ps(y2, ysin2);
    y = _mm256_sub_ps(y, ysin1);

    xmm1 = _mm256_add_ps(ysin1, ysin2);
    xmm2 = _mm256_add_ps(y, y2);

    /* update the sign */
    *s = _mm256_xor_ps(xmm1, sign_bit_sin);
    *c = _mm256_xor_ps(xmm2, sign_bit_cos);
}
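
/* Minimal usage sketch for sincos256_ps (illustrative only; `src`, `dst_sin`
   and `dst_cos` are hypothetical float buffers of at least 8 elements):

     v8sf s, c;
     sincos256_ps(_mm256_loadu_ps(src), &s, &c);
     _mm256_storeu_ps(dst_sin, s);
     _mm256_storeu_ps(dst_cos, c);
*/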

#undef __AVX2__
#undef ALIGN32_BEG
#undef ALIGN32_END
}  // namespace detail
}  // namespace x86
}  // namespace megdnn

// vim: ft=cpp syntax=cpp.doxygen
